{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from matplotlib import cm\n",
    "Blues = plt.get_cmap('Blues', 12) # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "import timeit\n",
    "import math\n",
    "import scipy\n",
    "import scipy.linalg\n",
    "import numpy.linalg\n",
    "import scipy.special\n",
    "import scipy.sparse\n",
    "import scipy.sparse.linalg as ssla\n",
    "import scipy.optimize\n",
    "from mpmath import mp\n",
    "from datetime import date\n",
    "today = date.today()\n",
    "# NOTE: this rebinds the name `date` from the datetime.date class to a string such as\n",
    "# 'Jan012024'; later cells embed that string in output file names (paramString).\n",
    "date = today.strftime(\"%b%d%Y\")\n",
    "\n",
    "from sympy import *\n",
    "from sympy.vector import CoordSys3D\n",
    "\n",
    "# Physical constants in meV / nm / s units\n",
    "hbar= 6.582119569509e-13 #meV*second\n",
    "electronMass= 5.68563006e-27 # meV * (second/nm)^2\n",
    "# Coulomb constant e^2/(4 pi eps0); the dielectric constant eps is applied separately (Lambda)\n",
    "eSquaredOvere = 1439.964547  # meV * nm\n",
    "\n",
    "# Plot styling: TrueType (type 42) fonts in saved PDFs, small square figures\n",
    "mpl.rcParams['pdf.fonttype'] = 42\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "plt.rcParams[\"figure.figsize\"] = (2.00,2.00)\n",
    "mpl.rcParams['font.family'] = 'Arial'\n",
    "np.set_printoptions(precision=10,suppress=True,linewidth=1000)\n",
    "from numba import njit\n",
    "import numba\n",
    "# NOTE(review): absolute, machine-specific output directories; figures (plt.savefig) and\n",
    "# density arrays (np.save) are written here, so adjust before running elsewhere.\n",
    "figDir='/Users/aidanreddy/Desktop/Research/plots/wignermolecule/'\n",
    "dataDir='/Users/aidanreddy/Desktop/Research/data/wignermoleculedata/'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\\begin{aligned}\n",
    "V(\\mathbf{r}) &= \\alpha_2 r^2+\\alpha_3\\cos(3\\theta)\\,r^3+\\dots \\\\\n",
    "&= \\beta_2 (r/l)^2+\\beta_3\\cos(3\\theta)\\,(r/l)^3+\\dots\n",
    "\\end{aligned}$$\n",
    "\n",
    "$$\\mathcal{H}=\\hbar\\omega\\left(a_+^{\\dagger}a_+ + a_-^{\\dagger}a_- + 1 + \\gamma\\,\\frac{[(a_+^{\\dagger}+a_-)^3+(a_-^{\\dagger}+a_+)^3]}{2^{5/2}}\\right)$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "WORKHORSE FUNCTIONS\n",
    "\"\"\"\n",
    "\n",
    "def GenDiagEnergy(ns,ls):\n",
    "    \"\"\"Total harmonic-oscillator energy per Fock state: row-wise sum of 2n + |l| + 1.\"\"\"\n",
    "    perOrbital = 2*ns + np.abs(ls) + 1\n",
    "    return perOrbital.sum(axis=1)\n",
    "\n",
    "@njit\n",
    "def factorial(n): #for numba compatibility\n",
    "    # n! returned as a float (1.0 for n <= 1); hand-rolled so numba can compile it.\n",
    "    result = 1.0\n",
    "    k = 2\n",
    "    while k <= n:\n",
    "        result *= k\n",
    "        k += 1\n",
    "    return result\n",
    "\n",
    "# Not jitted: relies on scipy.special.genlaguerre and math.factorial.\n",
    "def fockDarwin(n,l,rhoThetaVals):\n",
    "    \"\"\"\n",
    "    Fock-Darwin (2D oscillator) orbital with quantum numbers (n, l) evaluated at a set of points.\n",
    "\n",
    "    rhoThetaVals: (npts, 2) array of (rho, theta); abs() is applied to rho.\n",
    "    Returns the complex array N rho^|l| exp(-rho^2/2) L_n^{|l|}(rho^2) (-1)^n exp(i l theta),\n",
    "    with N = sqrt(n! / (pi (n+|l|)!)). The (-1)^n factor is NOT the usual literature convention.\n",
    "    \"\"\"\n",
    "    absL = abs(l)\n",
    "    rho = np.abs(rhoThetaVals[:, 0])\n",
    "    theta = rhoThetaVals[:, 1]\n",
    "    norm = np.sqrt(math.factorial(n) / (np.pi * math.factorial(n + absL)))\n",
    "    laguerre = scipy.special.genlaguerre(n, absL)\n",
    "    radial = (-1)**n * rho**absL * np.exp(-rho**2 / 2) * laguerre(rho**2)\n",
    "    angular = np.exp(1j * l * theta)\n",
    "    return norm * np.einsum('i,i->i', radial, angular)\n",
    "    \n",
    "# Not jitted: calls fockDarwin, which uses scipy.\n",
    "def natOrbWF(rhoThetaVals,natOrb,spBasis):\n",
    "    \"\"\"\n",
    "    Real-space wavefunction sum_i natOrb[i] * fockDarwin(n_i, l_i) on rhoThetaVals.\n",
    "\n",
    "    spBasis: (nSp, 2) rows of (n, l). Coefficients with |natOrb[i]| < 1e-10 are skipped;\n",
    "    returns 0 (a scalar) if every coefficient is skipped.\n",
    "    \"\"\"\n",
    "    wavefunction = 0\n",
    "    for idx in range(spBasis.shape[0]):\n",
    "        weight = natOrb[idx]\n",
    "        if not abs(weight) >= 1e-10:\n",
    "            continue\n",
    "        n, l = spBasis[idx]\n",
    "        wavefunction += weight * fockDarwin(n, l, rhoThetaVals)\n",
    "    return wavefunction\n",
    "\n",
    "\n",
    "@njit\n",
    "def CME(n1pi,n1mi,n2pi,n2mi,n1pj,n1mj,n2pj,n2mj): # coulomb matrix element\n",
    "    # Two-body Coulomb matrix element in the (n+, n-) occupation representation. Arguments\n",
    "    # ending in i describe particles 1 and 2 on one side (AidanV passes its primed orbitals\n",
    "    # here) and those ending in j the other side (AidanV's unprimed orbitals).\n",
    "    # Returns 0 unless the total l = sum(n+) - sum(n-) agrees on both sides; otherwise a\n",
    "    # complex number. NOTE(review): units presumably e^2/(eps*l), since Hint is later scaled\n",
    "    # by Lambda = e^2/(eps*l)/hbaromega; confirm.\n",
    "    nsum = n1pi+n1mi+n2pi+n2mi+n1pj+n1mj+n2pj+n2mj\n",
    "    dl = (n1pi+n2pi-n1mi-n2mi)-(n1pj+n2pj-n1mj-n2mj)\n",
    "    deltaN2 = (n2pi+n2mi)-(n2pj+n2mj)\n",
    "    deltaN1= (n1pi+n1mi)-(n1pj+n1mj)\n",
    "    deltaNtot=deltaN2+deltaN1\n",
    "    if dl!=0:\n",
    "        return 0 #angular momentum conservation\n",
    "    # Four nested finite sums over k1p, k1m, k2p, k2m. Each prefactor a starts at\n",
    "    # 1/(ni! nj!) and is updated by a *= -(ni-k)(nj-k)/(k+1), so the k-th term carries\n",
    "    # (-1)^k / (k! (ni-k)! (nj-k)!).\n",
    "    S1p = 0 #sum over k1p\n",
    "    a1p = 1/factorial(n1pi)/factorial(n1pj) #prefactor of each term in S1p\n",
    "    for k1p in range(min(n1pi,n1pj)+1):\n",
    "        S1m = 0 #sum over k1m\n",
    "        a1m = 1/factorial(n1mi)/factorial(n1mj)\n",
    "        for k1m in range(min(n1mi,n1mj)+1):\n",
    "            S2p = 0\n",
    "            a2p = 1/factorial(n2pi)/factorial(n2pj)\n",
    "            for k2p in range(min(n2pi,n2pj)+1):\n",
    "                S2m = 0\n",
    "                a2m = 1/factorial(n2mi)/factorial(n2mj)\n",
    "                for k2m in range(min(n2mi,n2mj)+1):\n",
    "                    # Innermost integral: I = 2^(-(p+3)/2) * Gamma((p+1)/2) with\n",
    "                    # p = nsum - 2*(k1p+k1m+k2p+k2m).\n",
    "                    p = nsum-2*(k1p+k1m+k2p+k2m)\n",
    "                    I = 2**(-(p+3)/2)*math.gamma((p+1)/2)\n",
    "                    S2m += a2m*I\n",
    "                    a2m *= -(n2mi-k2m)*(n2mj-k2m)/(k2m+1)\n",
    "                S2p += a2p*S2m\n",
    "                a2p *= -(n2pi-k2p)*(n2pj-k2p)/(k2p+1)\n",
    "            S1m += a1m*S2p\n",
    "            a1m *= -(n1mi-k1m)*(n1mj-k1m)/(k1m+1)\n",
    "        S1p += a1p*S1m\n",
    "        a1p *= -(n1pi-k1p)*(n1pj-k1p)/(k1p+1)\n",
    "    # Overall factor: 2 * (-1)^|deltaN2| * i^deltaNtot * (-1)^(sum of j-side occupations),\n",
    "    # times sqrt of the product of all eight occupation factorials.\n",
    "    Vij = S1p*2*(-1)**(abs(deltaN2))*(1j)**(deltaNtot)*((-1)**abs(n1pj+n1mj+n2pj+n2mj))*np.sqrt(factorial(n1pi)*factorial(n1pj)*factorial(n1mi)*factorial(n1mj)*factorial(n2pi)*factorial(n2pj)*factorial(n2mi)*factorial(n2mj))\n",
    "    return Vij\n",
    "\n",
    "def generateCMatrixFDBasis(spBasis):\n",
    "    \"\"\"\n",
    "    Coulomb tensor CMat[i,j,k,l] = AidanV(n_i,l_i, n_j,l_j, n_k,l_k, n_l,l_l) over the\n",
    "    single-particle basis spBasis ((nSp, 2) rows of (n, l)). Complex; nSp**4 AidanV calls.\n",
    "    \"\"\"\n",
    "    size = spBasis.shape[0]\n",
    "    rows = [tuple(state) for state in spBasis]\n",
    "    CMat = np.zeros((size, size, size, size), dtype='complex')\n",
    "    # np.ndindex yields (i, j, k, l) in the same row-major order as four nested loops.\n",
    "    for i, j, k, l in np.ndindex(size, size, size, size):\n",
    "        ni, li = rows[i]\n",
    "        nj, lj = rows[j]\n",
    "        nk, lk = rows[k]\n",
    "        nl, ll = rows[l]\n",
    "        CMat[i, j, k, l] += AidanV(ni, li, nj, lj, nk, lk, nl, ll)\n",
    "    return CMat\n",
    "\n",
    "def generateCMatrixAltBasis(AltBasis,CMatFDBasis):\n",
    "    \"\"\"\n",
    "    Rotate the Coulomb tensor into another single-particle basis U = AltBasis:\n",
    "    out[m,n,o,p] = sum_{ijkl} conj(U[i,m]) conj(U[j,n]) U[k,o] U[l,p] C[i,j,k,l],\n",
    "    i.e. column m of AltBasis holds the coefficients of new orbital m.\n",
    "    \"\"\"\n",
    "    Ustar = np.conj(AltBasis)\n",
    "    return np.einsum('im,jn,ko,lp,ijkl->mnop', Ustar, Ustar, AltBasis, AltBasis, CMatFDBasis)\n",
    "\n",
    "def _fdPlusMinus(n, l):\n",
    "    # (n, l) -> (n+, n-) with n+ = n + (|l|+l)//2 and n- = n+ - l.\n",
    "    nPlus = n + (abs(l) + l)//2\n",
    "    return nPlus, nPlus - l\n",
    "\n",
    "def AidanV(n1_prime,l1_prime,n2_prime,l2_prime,n1,l1,n2,l2):\n",
    "    \"\"\"\n",
    "    Coulomb matrix element for Fock-Darwin orbitals given as (n, l) pairs: converts every\n",
    "    pair to (n+, n-) and calls CME with the primed pair as its i-side arguments and the\n",
    "    unprimed pair as its j-side arguments.\n",
    "    \"\"\"\n",
    "    n1p_prime, n1m_prime = _fdPlusMinus(n1_prime, l1_prime)\n",
    "    n2p_prime, n2m_prime = _fdPlusMinus(n2_prime, l2_prime)\n",
    "    n1p, n1m = _fdPlusMinus(n1, l1)\n",
    "    n2p, n2m = _fdPlusMinus(n2, l2)\n",
    "    return CME(n1p_prime,n1m_prime,n2p_prime,n2m_prime,n1p,n1m,n2p,n2m)\n",
    "\n",
    "def GenerateFDStates(N): # my version\n",
    "    \"\"\"\n",
    "    Fock-Darwin single-particle basis up to shell N, as an int array of (n, l) rows.\n",
    "\n",
    "    States are enumerated through (n+, n-) with n+ + n- <= N, using n = min(n+, n-) and\n",
    "    l = n+ - n-, which gives (N+1)(N+2)/2 states. Rows are sorted by shell 2n + |l| (= n+ + n-)\n",
    "    and, within a shell, by increasing l. The basis for N is therefore a prefix of the basis\n",
    "    for any larger N.\n",
    "    \"\"\"\n",
    "    stateList = np.array([(min(nP, nM), nP - nM)\n",
    "                          for nP in range(N + 1)\n",
    "                          for nM in range(N - nP + 1)], dtype=int).reshape(-1, 2)\n",
    "    E0 = 2*stateList[:, 0] + np.abs(stateList[:, 1])\n",
    "    # np.lexsort treats the last key as primary: energy first, then l. This replaces the\n",
    "    # former key E0 + 0.1*l/N, which divided by zero (NaN + RuntimeWarning) for N = 0.\n",
    "    ind_sort = np.lexsort((stateList[:, 1], E0))\n",
    "    return stateList[ind_sort]\n",
    "\n",
    "def GenerateFockStates(OneBodyStates,Nup,Ndn,L,mod=1000):\n",
    "    \"\"\"\n",
    "    Fock states with Nup spin-up and Ndn spin-down particles whose total l equals L modulo mod.\n",
    "\n",
    "    OneBodyStates: (nstates, 2) rows of (n, l). Spin-up orbitals are indexed 0..nstates-1 and\n",
    "    spin-down orbitals nstates..2*nstates-1. Each result row lists the up indices and then the\n",
    "    down indices, each in increasing order. The default mod=1000 acts as exact l conservation\n",
    "    for small bases.\n",
    "    Returns an array with one row per state; shape (0, Nup+Ndn) when nothing matches.\n",
    "    \"\"\"\n",
    "    from itertools import combinations\n",
    "    ls = OneBodyStates[:,1]\n",
    "    nstates = len(ls)\n",
    "    target = L%mod\n",
    "    # Total l of every spin-down choice, computed once rather than once per spin-up choice.\n",
    "    dnChoices = [(DnStates, sum(ls[i-nstates] for i in DnStates))\n",
    "                 for DnStates in combinations(range(nstates,2*nstates),Ndn)]\n",
    "    states = []\n",
    "    for UpStates in combinations(range(nstates),Nup):\n",
    "        l_up = sum(ls[i] for i in UpStates)\n",
    "        for DnStates, l_dn in dnChoices:\n",
    "            if (l_up+l_dn)%mod==target:\n",
    "                states.append(list(UpStates)+list(DnStates))\n",
    "    if not states:\n",
    "        # Keep a 2-D shape so callers can still index rows and columns of an empty basis.\n",
    "        return np.zeros((0, Nup+Ndn), dtype=int)\n",
    "    return np.array(states)\n",
    "\n",
    "def generateAltFockBasis(nSp,NU,ND):\n",
    "    \"\"\"\n",
    "    All Fock states with NU spin-up and ND spin-down particles in nSp orbitals per spin, with\n",
    "    no angular-momentum filtering. Indexing and row layout match GenerateFockStates (up indices\n",
    "    0..nSp-1 first, then down indices nSp..2*nSp-1). An empty result keeps shape (0, NU+ND).\n",
    "    \"\"\"\n",
    "    from itertools import combinations\n",
    "    dnChoices = list(combinations(range(nSp,2*nSp),ND))\n",
    "    states = [list(UpStates)+list(DnStates)\n",
    "              for UpStates in combinations(range(nSp),NU)\n",
    "              for DnStates in dnChoices]\n",
    "    if not states:\n",
    "        return np.zeros((0, NU+ND), dtype=int)\n",
    "    return np.array(states)\n",
    "\n",
    "def generateHspMatFD(fdBasis):\n",
    "    \"\"\"Diagonal oscillator Hamiltonian in the (n, l) basis: entries 2n + |l| + 1 (units of hbar*omega).\"\"\"\n",
    "    radial = fdBasis[:, 0]\n",
    "    angular = fdBasis[:, 1]\n",
    "    return np.diag(2*radial + np.abs(angular) + 1)\n",
    "\n",
    "def generateH0(nSp,fockBasis,Hsp,oneParticleDiffDict):\n",
    "    \"\"\"\n",
    "    Single-particle part of the many-body Hamiltonian in the Fock basis.\n",
    "\n",
    "    nSp: orbitals per spin (spin-orbital s uses orbital s % nSp).\n",
    "    fockBasis: (nF, Np) array of occupied spin-orbital indices per Fock state.\n",
    "    Hsp: (nSp, nSp) single-particle Hamiltonian shared by both spins.\n",
    "    oneParticleDiffDict: {i: indices j of Fock states differing from state i by one particle}.\n",
    "    Returns the complex (nF, nF) matrix; hops that would change spin are skipped.\n",
    "    \"\"\"\n",
    "    dim = fockBasis.shape[0]\n",
    "    H0 = np.zeros((dim, dim), dtype='complex')\n",
    "    # Diagonal: sum of Hsp diagonal entries over the occupied orbitals.\n",
    "    for row, occupied in enumerate(fockBasis):\n",
    "        for so in occupied:\n",
    "            H0[row, row] += Hsp[so % nSp, so % nSp]\n",
    "    # Off-diagonal: <row| c_spa^dag c_spb |col> for states connected by one hop.\n",
    "    for row, partners in oneParticleDiffDict.items():\n",
    "        occRow = fockBasis[row]\n",
    "        for col in partners:\n",
    "            occCol = fockBasis[col]\n",
    "            spa = set(occRow).difference(set(occCol)).pop()\n",
    "            spb = set(occCol).difference(set(occRow)).pop()\n",
    "            if (spa >= nSp) != (spb >= nSp):\n",
    "                continue  # different spins: Hsp does not connect them\n",
    "            lo, hi = (spa, spb) if spa < spb else (spb, spa)\n",
    "            # Fermionic sign: (-1)^(number of occupied orbitals strictly between spa and spb).\n",
    "            sign = (-1)**np.count_nonzero((occCol > lo) & (occCol < hi))\n",
    "            H0[row, col] += Hsp[spa % nSp, spb % nSp]*sign\n",
    "    return H0\n",
    "\n",
    "@njit\n",
    "def kd(a,b):\n",
    "    # Kronecker delta: integer 1 when a == b, otherwise 0.\n",
    "    return 1 if a == b else 0\n",
    "\n",
    "def calcap(fdBasis):\n",
    "    \"\"\"\n",
    "    Matrix of the a_+ ladder operator in the Fock-Darwin basis fdBasis ((nSp, 2) rows of (n, l)).\n",
    "\n",
    "    Each state maps to occupations n+ = n + (l+|l|)/2 and n- = n + (|l|-l)/2. Then\n",
    "    ap[i, j] = sqrt(n+_j) when (n+_i, n-_i) == (n+_j - 1, n-_j), and 0 otherwise (float64).\n",
    "    Broadcasting replaces the former O(nSp^2) Python loop of jitted kd calls.\n",
    "    \"\"\"\n",
    "    fdBasis = np.asarray(fdBasis)\n",
    "    nPlus = fdBasis[:,0]+(fdBasis[:,1]+np.abs(fdBasis[:,1]))/2\n",
    "    nMinus = fdBasis[:,0]+(-fdBasis[:,1]+np.abs(fdBasis[:,1]))/2\n",
    "    connected = (nPlus[:,None] == nPlus[None,:]-1) & (nMinus[:,None] == nMinus[None,:])\n",
    "    return np.where(connected, np.sqrt(nPlus)[None,:], 0.0)\n",
    "\n",
    "def dag(O):\n",
    "    \"\"\"Hermitian conjugate: complex conjugate of np.transpose(O) (all axes reversed).\"\"\"\n",
    "    transposed = np.transpose(O)\n",
    "    return np.conjugate(transposed)\n",
    "\n",
    "def calcam(fdBasis):\n",
    "    \"\"\"\n",
    "    Matrix of the a_- ladder operator in the Fock-Darwin basis fdBasis ((nSp, 2) rows of (n, l)).\n",
    "\n",
    "    Each state maps to occupations n+ = n + (l+|l|)/2 and n- = n + (|l|-l)/2. Then\n",
    "    am[i, j] = sqrt(n-_j) when (n+_i, n-_i) == (n+_j, n-_j - 1), and 0 otherwise (float64).\n",
    "    Broadcasting replaces the former O(nSp^2) Python loop of jitted kd calls.\n",
    "    \"\"\"\n",
    "    fdBasis = np.asarray(fdBasis)\n",
    "    nPlus = fdBasis[:,0]+(fdBasis[:,1]+np.abs(fdBasis[:,1]))/2\n",
    "    nMinus = fdBasis[:,0]+(-fdBasis[:,1]+np.abs(fdBasis[:,1]))/2\n",
    "    connected = (nPlus[:,None] == nPlus[None,:]) & (nMinus[:,None] == nMinus[None,:]-1)\n",
    "    return np.where(connected, np.sqrt(nMinus)[None,:], 0.0)\n",
    "\n",
    "\n",
    "def generateHspAltBasis(AltBasis, HspFDBasis):\n",
    "    \"\"\"Single-particle Hamiltonian in another basis: U^dagger H U with U = AltBasis.\"\"\"\n",
    "    Udag = np.conj(AltBasis).T\n",
    "    return np.einsum('ij,jk,kl->il', Udag, HspFDBasis, AltBasis)\n",
    "\n",
    "def canonical_interaction(num_1p,i,j,k,l):\n",
    "    \"\"\"\n",
    "    Drop spin (reduce indices mod num_1p) and return the smallest of the four index orderings\n",
    "    (i,j,k,l), (j,i,l,k), (k,l,i,j), (l,k,j,i), used as one shared key for equivalent elements.\n",
    "    \"\"\"\n",
    "    a, b, c, d = (idx % num_1p for idx in (i, j, k, l))\n",
    "    return min((a, b, c, d), (b, a, d, c), (c, d, a, b), (d, c, b, a))\n",
    "\n",
    "def GenerateHint(FockStates,CMat):\n",
    "    \"\"\"\n",
    "    Build the Coulomb interaction matrix in the many-body Fock basis.\n",
    "\n",
    "    FockStates: (nF, Np) int array; row i lists the occupied spin-orbital indices of Fock\n",
    "        state i (index // n = spin, index % n = orbital, n = CMat.shape[0]). The fermionic\n",
    "        signs assume every row is sorted ascending, as GenerateFockStates produces.\n",
    "    CMat: (n, n, n, n) interaction matrix elements, read at the index tuples returned by\n",
    "        canonical_interaction.\n",
    "\n",
    "    Returns (Hint, oneParticleDiffDict): the complex (nF, nF) interaction matrix, and a dict\n",
    "    mapping each state index i to the indices of states that differ from it by one particle.\n",
    "    \"\"\"\n",
    "    from collections import defaultdict\n",
    "    # combos maps a canonical CMat index tuple to every (sign, (row, col)) it contributes to,\n",
    "    # so each matrix element is read once at the end.\n",
    "    combos = defaultdict(list)\n",
    "    nstates = len(FockStates)\n",
    "    num_1pstates = CMat.shape[0]\n",
    "    # Occupation vectors over both spins. The width comes from CMat, not from the notebook\n",
    "    # global nSp, so the function no longer depends on which cell last assigned nSp.\n",
    "    StateVectors= np.zeros([FockStates.shape[0],num_1pstates*2],dtype=int)\n",
    "    oneParticleDiffDict=dict()\n",
    "    for i in range(nstates):\n",
    "        StateVectors[i][FockStates[i]]=1\n",
    "    for i in range(nstates):\n",
    "        FockState = np.array(FockStates[i])\n",
    "        # Diagonal terms: a direct term for every occupied pair, plus exchange when the\n",
    "        # two particles have the same spin.\n",
    "        for a in range(len(FockState)):\n",
    "            for b in range(a+1,len(FockState)):\n",
    "                alpha = FockState[a]\n",
    "                beta = FockState[b]\n",
    "                SameSpin = (beta<num_1pstates and alpha<num_1pstates) or (beta>=num_1pstates and alpha>=num_1pstates)\n",
    "                if SameSpin:\n",
    "                    direct = (alpha,beta,alpha,beta)\n",
    "                    exchange = (beta,alpha,alpha,beta)\n",
    "                    combos[canonical_interaction(num_1pstates,*direct)].append((1,(i,i)))\n",
    "                    combos[canonical_interaction(num_1pstates,*exchange)].append((-1,(i,i)))\n",
    "                else:\n",
    "                    direct = (alpha,beta,alpha,beta)\n",
    "                    combos[canonical_interaction(num_1pstates,*direct)].append((1,(i,i)))\n",
    "        # Classify all states by how many occupations differ from state i:\n",
    "        # 2 differences = one particle moved, 4 = two particles moved.\n",
    "        i_StateVector = StateVectors[i]\n",
    "        delta = StateVectors[:,:] - i_StateVector[None,:]\n",
    "        num_diffs = (np.sum(delta!=0,axis=1))\n",
    "        one_particle_diff = np.where(num_diffs==2)[0]\n",
    "        oneParticleDiffDict[i]=one_particle_diff\n",
    "        two_particle_diff = np.where(num_diffs==4)[0]\n",
    "        # One-particle moves alpha -> alpha_prime, with every other occupied beta as partner.\n",
    "        for j in one_particle_diff:\n",
    "            j_StateVector = StateVectors[j]\n",
    "            alpha = np.where(delta[j]==-1)[0][0]\n",
    "            alpha_prime = np.where(delta[j]==1)[0][0]\n",
    "            for beta in FockState:\n",
    "                if beta == alpha: continue\n",
    "                SameSpin = (beta<num_1pstates and alpha<num_1pstates) or (beta>=num_1pstates and alpha>=num_1pstates)\n",
    "                if SameSpin:\n",
    "                    direct = (alpha_prime,beta,alpha,beta)\n",
    "                    exchange = (beta,alpha_prime,alpha,beta)\n",
    "                    # Fermionic sign from the occupied orbitals preceding alpha (in i) and alpha_prime (in j).\n",
    "                    sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])))\n",
    "                    combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))\n",
    "                    combos[canonical_interaction(num_1pstates,*exchange)].append((-sign,(j,i)))\n",
    "                else:\n",
    "                    direct = (alpha_prime,beta,alpha,beta)\n",
    "                    sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])))\n",
    "                    combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))\n",
    "        # Two-particle moves (alpha, beta) -> (alpha_prime, beta_prime).\n",
    "        for j in two_particle_diff:\n",
    "            j_StateVector = StateVectors[j]\n",
    "            alpha,beta = np.where(delta[j]==-1)[0]\n",
    "            alpha_prime,beta_prime = np.where(delta[j]==1)[0]\n",
    "            SameSpin = (alpha < num_1pstates and beta < num_1pstates) or (alpha >= num_1pstates and beta >= num_1pstates)\n",
    "            if SameSpin:\n",
    "                direct = (alpha_prime,beta_prime,alpha,beta)\n",
    "                exchange = (beta_prime,alpha_prime,alpha,beta)\n",
    "                sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])\n",
    "                                 +np.sum(i_StateVector[:beta])-np.sum(j_StateVector[:beta_prime])))\n",
    "                combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))\n",
    "                combos[canonical_interaction(num_1pstates,*exchange)].append((-sign,(j,i)))\n",
    "            else:\n",
    "                direct = (alpha_prime,beta_prime,alpha,beta)\n",
    "                sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])\n",
    "                                 +np.sum(i_StateVector[:beta])-np.sum(j_StateVector[:beta_prime])))\n",
    "                combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))\n",
    "    # Accumulate: each canonical matrix element is read from CMat once and scattered.\n",
    "    Hint = np.zeros([FockStates.shape[0]]*2, dtype='complex')\n",
    "    for stateindices,sign_jis in combos.items():\n",
    "        a,b,c,d = stateindices\n",
    "        v=CMat[a,b,c,d]\n",
    "        for sign,(j,i) in sign_jis:\n",
    "            Hint[j,i] += sign*v\n",
    "    return Hint, oneParticleDiffDict\n",
    "    \n",
    "\n",
    "def oneBodyDMat(eState,oneParticleDiffDict,spBasis,fBasis):\n",
    "    \"\"\"\n",
    "    Spin-resolved one-body density matrix of the many-body state eState.\n",
    "\n",
    "    Returns a complex (2, nSp, nSp) array D with D[s, a, b] = <c_b^dag c_a> for orbitals a, b\n",
    "    of spin s (s = index // nSp). spBasis only supplies nSp. fBasis and oneParticleDiffDict are\n",
    "    the Fock basis and its one-particle-difference map (as returned by GenerateHint).\n",
    "    \"\"\"\n",
    "    nOrb = spBasis.shape[0]\n",
    "    D = np.zeros((2, nOrb, nOrb), dtype='complex')\n",
    "    # Diagonal Fock-state terms: weight |c_i|^2 on every occupied orbital of state i.\n",
    "    for idx, occupied in enumerate(fBasis):\n",
    "        weight = np.conj(eState[idx])*eState[idx]\n",
    "        for so in occupied:\n",
    "            D[abs(so // nOrb), so % nOrb, so % nOrb] += weight\n",
    "    # Off-diagonal terms: <j| c_spb^dag c_spa |i> = sign for states connected by one hop.\n",
    "    for idx, partners in oneParticleDiffDict.items():\n",
    "        occI = fBasis[idx]\n",
    "        for jdx in partners:\n",
    "            occJ = fBasis[jdx]\n",
    "            spa = set(occI).difference(set(occJ)).pop()\n",
    "            spb = set(occJ).difference(set(occI)).pop()\n",
    "            if (spa >= nOrb) != (spb >= nOrb):\n",
    "                continue  # spin flip: no same-spin density-matrix element\n",
    "            lo, hi = (spa, spb) if spa < spb else (spb, spa)\n",
    "            sign = (-1)**np.count_nonzero((occJ > lo) & (occJ < hi))\n",
    "            D[abs(spa // nOrb), spa % nOrb, spb % nOrb] += np.conj(eState[jdx])*eState[idx]*sign\n",
    "    return D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Strained moiré Wigner molecule\n",
    "\"\"\"\n",
    "\n",
    "# --- Physical and numerical parameters ---\n",
    "# phi: phase of the moire potential (substituted for Phi in mPCart below), in radians\n",
    "phi = 20.0 * np.pi/180\n",
    "# NOTE(review): dividing by cos(phi) presumably keeps V*cos(phi) fixed at 9.65 meV; confirm\n",
    "V = 9.65/np.cos(phi) #meV\n",
    "# eps: dielectric constant (divides the Coulomb scale e^2/l in Lambda below)\n",
    "eps = 10 #5\n",
    "# mStar: effective mass in units of the bare electron mass (multiplies electronMass)\n",
    "mStar = 0.9\n",
    "# aM: moire lattice constant, |a1| before strain\n",
    "aM = 9.8 #nm\n",
    "L = 0 # orbital angular momentum of state (mod rotation symmetry breaking)\n",
    "\n",
    "N = 5 #9 #maximum shell in single particle basis\n",
    "\n",
    "# nTaylor: highest total order n+m kept in the x^n y^m Taylor expansion of the potential\n",
    "nTaylor = 6 # go to order r^(nTaylor)\n",
    "\n",
    "# Uniaxial strain: scale by 1 +/- stretchfactor/2 along axes rotated by stretchangle\n",
    "stretchfactor = 0.0\n",
    "stretchangle = 0 #degrees\n",
    "\n",
    "# mod: total angular momentum is filtered only modulo `mod` when building the Fock basis;\n",
    "# the unstrained potential (three cosines with B3 = -(B1+B2)) allows mod 3.\n",
    "if stretchfactor == 0:\n",
    "    mod = 3\n",
    "else:\n",
    "    mod = 1 #angular momentum not conserved for generic strain\n",
    "\n",
    "# NU / ND: numbers of spin-up / spin-down electrons\n",
    "NU = 4\n",
    "ND = 0\n",
    "\n",
    "# Real-space lattice vectors (nm), then strained by lintrans = R(theta) S R(-theta)\n",
    "a1= aM*np.array([1.0, 0, 0]) #np.array([11.04, 0, 0]) #LONG AM REGION\n",
    "a2= aM*np.array([1/2, np.sqrt(3)/2, 0]) #np.array([3.411, 7.883, 0])\n",
    "theta=stretchangle*(np.pi/180)\n",
    "stretchMat=np.array([[1+stretchfactor/2,0,0],[0,1-stretchfactor/2,0],[0,0,0]])\n",
    "rotMat = np.array([[np.cos(theta), -np.sin(theta), 0],[np.sin(theta), np.cos(theta), 0],[0,0,1]])\n",
    "rotInv =  np.array([[np.cos(-theta), -np.sin(-theta), 0],[np.sin(-theta), np.cos(-theta), 0],[0,0,1]])\n",
    "lintrans = rotMat @ stretchMat @ rotInv\n",
    "a1 = lintrans @ a1\n",
    "a2 = lintrans @ a2\n",
    "\n",
    "ez = np.array([0,0,1])\n",
    "\n",
    "# Signed unit-cell area (nm^2)\n",
    "A = np.cross(a1, a2)[2]\n",
    "\n",
    "# Reciprocal lattice vectors, satisfying a_i . b_j = 2*pi*delta_ij\n",
    "b1 = 2*np.pi*(np.cross(a2, ez)) / A\n",
    "b2 = 2*np.pi*(np.cross(ez, a1)) / A\n",
    "\n",
    "(b1x, b1y) = b1[:2]\n",
    "\n",
    "(b2x, b2y) = b2[:2]\n",
    "\n",
    "print((b1x, b1y))\n",
    "print((b2x, b2y))\n",
    "\n",
    "# |b2| / (4*pi/sqrt(3)); equals 1/aM for the unstrained lattice\n",
    "print(np.sqrt(b2x**2 + b2y**2)/(4*np.pi / (np.sqrt(3))))\n",
    "\n",
    "#aM = np.sqrt(A*2/np.sqrt(3)) #4*np.pi / (np.sqrt(3)) / ((np.sqrt(b1x**2 + b1y**2) + np.sqrt(b2x**2 + b2y**2)) / 2)\n",
    "\n",
    "print(\"aM:\", aM)\n",
    "\n",
    "#useCFBasis==True\n",
    "\n",
    "# Output-file tag encoding the run parameters and today's date (`date` is the string set in\n",
    "# the first cell)\n",
    "paramString='WigMolNU%dND%dN%dV=%.2fmStar=%.2faM=%.2fphi=%.2fb1=(%.2f, %.2f)b2=(%.2f,%.2f)ntaylor%dstretchfactor%.2fangle%.2f%s'%(NU,ND,N,V,mStar,aM,phi*(180/np.pi),b1x,b1y,b2x,b2y,nTaylor,stretchfactor,stretchangle,date)\n",
    "\n",
    "FDBasis=GenerateFDStates(N)\n",
    "\n",
    "nSp=FDBasis.shape[0]\n",
    "print('nSp:', nSp)\n",
    "\n",
    "tic=timeit.default_timer()\n",
    "FDBasisEnlarged=GenerateFDStates(N+2*nTaylor) #enlarge basis to allow for higher order matrix elements to be treated properly\n",
    "# Sanity check: the first nSp states of the enlarged basis must coincide with FDBasis, because\n",
    "# the crystal-field matrix built on FDBasisEnlarged is later truncated to [0:nSp, 0:nSp].\n",
    "for i  in range(nSp):\n",
    "    if np.array_equal(FDBasis[i],FDBasisEnlarged[i])==False:\n",
    "        print('enlarged and truncated FB basis mismatch!')\n",
    "\n",
    "#SET UP CRYSTAL FIELD\n",
    "\n",
    "vec = CoordSys3D('vec')\n",
    "\n",
    "#Defining operators on the enlarged Fock-Darwin basis\n",
    "# ap/am: a_+ and a_- ladder matrices; rp, rm, pp, pm and the position matrices x, y are\n",
    "# built from them (dimensionless, in units of the oscillator length l; see coeff below).\n",
    "ap=calcap(FDBasisEnlarged)\n",
    "am=calcam(FDBasisEnlarged)\n",
    "rp=ap+dag(am)\n",
    "rm=dag(ap)+am\n",
    "pp=(ap-dag(am))/2j\n",
    "pm=(am-dag(ap))/2j\n",
    "x=((rp+rm)/2).astype('complex')\n",
    "y=(rp-rm)/(2j)\n",
    "\n",
    "# Kinetic energy matrix in units of hbaromega\n",
    "HspKineticMatFD = (4*pm@pp)/2\n",
    "\n",
    "B1x, B1y, B2x, B2y, Phi, X, Y = symbols('B1x B1y B2x B2y Phi X Y')\n",
    "\n",
    "# Symbolic moire potential -2V * sum_k cos(r . B_k + Phi), with B3 = -(B1 + B2)\n",
    "rCart = X*vec.i +  Y*vec.j\n",
    "B1 = B1x*vec.i + B1y*vec.j\n",
    "B2 = B2x*vec.i + B2y*vec.j\n",
    "B3 = -(B1 + B2)\n",
    "\n",
    "mPCart = -2 * V * (cos(rCart.dot(B1) + Phi) + cos(rCart.dot(B2) + Phi) + cos(rCart.dot(B3) + Phi))\n",
    "\n",
    "HspCFMatFD = np.zeros_like(x)\n",
    "\n",
    "#spring constant: average curvature (d2V/dX2 + d2V/dY2)/2 of the potential at the origin\n",
    "\n",
    "kAvg = float((diff(mPCart, X, 2) + diff(mPCart, Y, 2)).subs([(X, 0), (Y, 0), (Phi, phi), (B1x, b1x), (B1y, b1y), (B2x, b2x), (B2y, b2y)])) / 2\n",
    "\n",
    "print(\"kAvg\", kAvg)\n",
    "\n",
    "hbaromega = hbar * np.sqrt(kAvg/(mStar * electronMass)) #Hamiltonian will be in units of hbaromega\n",
    "\n",
    "# Oscillator length l (nm) for mass mStar*electronMass and energy hbaromega\n",
    "l = np.sqrt( hbar**2 / (mStar * electronMass * hbaromega) )\n",
    "\n",
    "# Dimensionless interaction strength: Coulomb scale e^2/(eps*l) divided by hbaromega\n",
    "Lambda = eSquaredOvere / (eps * l) / hbaromega\n",
    "\n",
    "# NOTE(review): xic looks like a classical length scale (Lambda/4)^(1/3) * l; confirm meaning\n",
    "xic  = (Lambda/4)**(1/3) * l \n",
    "\n",
    "print(\"l, hbaromega, lambda, xic, l/aM:\", l, hbaromega, Lambda, xic, (l/aM))\n",
    "\n",
    "# Taylor-expand the potential: HspCFMatFD = sum_{n+m<=nTaylor} c_nm x^n y^m, with\n",
    "# c_nm = d^(n+m)V/dX^n dY^m at 0 times l^(n+m) / (n! m! hbaromega). When cells run in order,\n",
    "# factorial here is the numba helper from the functions cell (it shadows sympy's factorial).\n",
    "coeffArray = np.zeros((nTaylor+1,nTaylor+1))\n",
    "\n",
    "for n in range(nTaylor+1):\n",
    "    for m in range(nTaylor+1):\n",
    "        if n+m <= nTaylor:\n",
    "            coeff = diff(diff(mPCart, X, n), Y, m).subs([(X, 0), (Y, 0), (Phi, phi), (B1x, b1x), (B1y, b1y), (B2x, b2x), (B2y, b2y)]) * (l)**(n+m) / factorial(n) / factorial(m) / hbaromega\n",
    "            coeffArray[n,m] = coeff\n",
    "            #print(np.abs((numpy.linalg.matrix_power(x, n) @ numpy.linalg.matrix_power(y, m))))\n",
    "            HspCFMatFD += np.real(complex(coeff)) * (numpy.linalg.matrix_power(x, n) @ numpy.linalg.matrix_power(y, m))\n",
    "\n",
    "print(\"coeffArray[2,2]: \", coeffArray[2,2])\n",
    "\n",
    "#HspCFMatFD += 100 * numpy.linalg.matrix_power(x, 2) #for checking consistent definition of X/Y\n",
    " \n",
    "# print(\"diag(real(HspCFMatFD)): \\n\", np.diag(np.real(HspCFMatFD)))\n",
    "# print(\"real(HspKineticMatFD): \\n\", np.real(HspKineticMatFD))\n",
    "\n",
    "toc=timeit.default_timer()\n",
    "print('crystal field time (s):', toc-tic)\n",
    "\n",
    "# Truncate to the physical basis only after the matrix products above, which need the\n",
    "# enlarged basis to be accurate near the cutoff.\n",
    "HspMatFD = (HspKineticMatFD + HspCFMatFD)[0:nSp,0:nSp] #trim excess\n",
    "\n",
    "# print(\"real(HspMatFD): \\n\", np.real(HspMatFD))\n",
    "\n",
    "eValsSP=np.linalg.eigvalsh(HspMatFD)\n",
    "uniqueEVals, Degens = np.unique(np.around(eValsSP,12),return_counts=True)\n",
    "summary=np.stack((uniqueEVals, Degens), axis=1)\n",
    "print('single particle eVals and degeneracy:\\n', summary[0:8])\n",
    "\n",
    "# --- Many-body problem: Fock basis, interaction and single-particle parts ---\n",
    "fockBasisFD = GenerateFockStates(FDBasis,NU,ND,L,mod)\n",
    "print('dim(H):', fockBasisFD.shape[0])\n",
    "tic=timeit.default_timer()\n",
    "CMatFD = generateCMatrixFDBasis(FDBasis)\n",
    "\n",
    "toc=timeit.default_timer()\n",
    "print('Coul mat elts time (s):', toc-tic)\n",
    "tic=timeit.default_timer()\n",
    "Hint, oneParticleDiffDictFD = GenerateHint(fockBasisFD,CMatFD)\n",
    "# Scale by Lambda = e^2/(eps*l)/hbaromega so H is in units of hbaromega\n",
    "# (NOTE(review): assumes CME returns matrix elements in units of e^2/(eps*l)).\n",
    "Hint *= Lambda\n",
    "\n",
    "toc=timeit.default_timer()\n",
    "print('Hint time (s):', toc-tic)\n",
    "tic=timeit.default_timer()\n",
    "H0=generateH0(nSp,fockBasisFD,HspMatFD,oneParticleDiffDictFD)\n",
    "H0EVals=np.linalg.eigvalsh(H0)\n",
    "uniqueEVals, Degens = np.unique(np.around(H0EVals,12),return_counts=True)\n",
    "summary=np.stack((uniqueEVals, Degens), axis=1)\n",
    "print('H0 evals:', summary[0:5])\n",
    "\n",
    "toc=timeit.default_timer()\n",
    "print('H0 time (s):', toc-tic)\n",
    "\n",
    "print()\n",
    "Es=[]\n",
    "gsVecs=[]\n",
    "\n",
    "sparse = True # True: ARPACK Lanczos for the lowest state only; False: dense full spectrum\n",
    "\n",
    "tic=timeit.default_timer()\n",
    "H = Hint  + H0\n",
    "if sparse:\n",
    "    HSparse=scipy.sparse.csr_matrix(H)\n",
    "    # maxiter is documented as an int and is stored in ARPACK's integer parameter array\n",
    "    # (32-bit in standard builds). A float such as 1e10 does not fit that slot, so pass a\n",
    "    # large int instead.\n",
    "    eVal, eVec = ssla.eigsh(HSparse,k=1,which='SA',maxiter=10**9, return_eigenvectors=True)\n",
    "    Es.append(eVal[0])\n",
    "    gs=eVec[:,0]\n",
    "else:\n",
    "    eVals, eVecs=np.linalg.eigh(H)\n",
    "    uniqueEVals, Degens = np.unique(np.around(eVals,12),return_counts=True)\n",
    "    summary=np.stack((uniqueEVals, Degens), axis=1)\n",
    "    Es.append(eVals[0])\n",
    "    gs=eVecs[:,0]\n",
    "\n",
    "# One-body density matrix of the ground state (per spin) and its natural orbitals\n",
    "rho=oneBodyDMat(gs,oneParticleDiffDictFD,FDBasis,fockBasisFD)\n",
    "natEVals, natEVecs=np.linalg.eigh(rho)\n",
    "uniqueNotOEVals, NatODegens = np.unique(np.around(natEVals,12),return_counts=True)\n",
    "summary=np.flip(np.stack((uniqueNotOEVals, NatODegens), axis=1),axis=0)\n",
    "gsVecs.append(gs)\n",
    "toc=timeit.default_timer()\n",
    "print('diag time:\\n', toc-tic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "PLOT CHARGE DENSITY OF MANY BODY STATE\n",
    "\"\"\"\n",
    "\n",
    "def computenofr(eState,oneParticleDiffDict,fockBasis,FDBasis,xMax,numx):\n",
    "    \"\"\"\n",
    "    Real-space charge density of the many-body state eState on a numx x numx grid.\n",
    "\n",
    "    eState: many-body amplitudes indexed like fockBasis; amplitudes with |c| <= 1e-7 are skipped.\n",
    "    oneParticleDiffDict: {i: indices of Fock states differing from fockBasis[i] by one particle}.\n",
    "    fockBasis: rows of occupied spin-orbital indices (index % nSp selects the FDBasis orbital).\n",
    "    FDBasis: (nSp, 2) rows of (n, l) evaluated with fockDarwin.\n",
    "    xMax, numx: grid half-width (same length unit as fockDarwin's rho) and points per axis.\n",
    "\n",
    "    Returns a real (numx, numx) array: axis 0 follows the first grid coordinate x, axis 1 the\n",
    "    second coordinate y, and the polar angle used is theta = arctan2(x, y).\n",
    "    \"\"\"\n",
    "    cutoff = 1e-7\n",
    "    xVals=np.linspace(-xMax,xMax,numx)\n",
    "    nSp=np.shape(FDBasis)[0]\n",
    "    nofr=np.zeros((numx**2),dtype=complex)\n",
    "    # Flattened grid: point (x_i, y_j) is row j + i*numx. arctan2(x, y) equals the former\n",
    "    # half-angle formula 2*arctan(x/(y+rho)) but stays finite at the origin and on the negative\n",
    "    # y axis, where the old expression was 0/0 = NaN (e.g. whenever numx is odd).\n",
    "    gridX, gridY = np.meshgrid(xVals, xVals, indexing='ij')\n",
    "    rhoThetaVals = np.column_stack((np.sqrt(gridX**2 + gridY**2).ravel(), np.arctan2(gridX, gridY).ravel()))\n",
    "    orbitalCache = {}\n",
    "    def orbital(sp):\n",
    "        # Fock-Darwin orbital of spin-orbital sp on the grid, evaluated once and then reused.\n",
    "        k = int(sp % nSp)\n",
    "        if k not in orbitalCache:\n",
    "            n, l = FDBasis[k]\n",
    "            orbitalCache[k] = fockDarwin(n, l, rhoThetaVals)\n",
    "        return orbitalCache[k]\n",
    "    # Diagonal Fock-state terms: |c_i phi_a|^2 for every occupied spin-orbital a.\n",
    "    for i,fi in enumerate(fockBasis):\n",
    "        if abs(eState[i])>cutoff:\n",
    "            for sp in fi:\n",
    "                nofr+=abs(eState[i]*orbital(sp))**2\n",
    "    # Off-diagonal terms between Fock states that differ by moving one particle.\n",
    "    for i in oneParticleDiffDict:\n",
    "        if abs(eState[i])>cutoff: #save computation time by neglecting negligible terms\n",
    "            fi=fockBasis[i]\n",
    "            for j in oneParticleDiffDict[i]:\n",
    "                if abs(eState[j])>cutoff:\n",
    "                    fj=fockBasis[j]\n",
    "                    spa=set(fi).difference(set(fj)).pop()\n",
    "                    spb=set(fj).difference(set(fi)).pop()\n",
    "                    if (spa>=nSp and spb>=nSp) or (spa<nSp and spb<nSp): #make sure same spin\n",
    "                        lo, hi = (spa, spb) if spa < spb else (spb, spa)\n",
    "                        sign=(-1)**np.count_nonzero((fj>lo)&(fj<hi))\n",
    "                        # NOTE(review): this conjugates the orbital of spa (which is in fi), the\n",
    "                        # opposite of the D[s, spa, spb] convention in oneBodyDMat. It yields the\n",
    "                        # fockDarwin-frame density at (rho, -theta). It may deliberately offset the\n",
    "                        # sign of y = (rp-rm)/(2j) in the crystal-field cell; confirm orientation\n",
    "                        # before changing it.\n",
    "                        nofr+=np.conj(eState[j]*orbital(spa))*eState[i]*orbital(spb)*sign\n",
    "    nofr=np.real(nofr.reshape(numx,numx))\n",
    "    return(nofr)\n",
    "\n",
    "plt.rcParams[\"figure.figsize\"] = (2.00,2.00)\n",
    "\n",
    "xMax = 6 # nm\n",
    "numx = 30\n",
    "\n",
    "# computenofr works in units of the oscillator length l, hence xMax/l\n",
    "nofr = computenofr(gsVecs[0],oneParticleDiffDictFD,fockBasisFD,FDBasis,xMax/l,numx)\n",
    "\n",
    "ax=plt.gca()\n",
    "colormap='magma'\n",
    "\n",
    "# Riemann-sum normalisation check (close to NU+ND when the window covers the density).\n",
    "# np.linspace(-xMax, xMax, numx) has spacing 2*xMax/(numx-1), so the area element in units\n",
    "# of l^2 uses numx-1, not numx.\n",
    "ntot = np.sum(nofr) * (2*(xMax/l)/(numx-1))**2\n",
    "\n",
    "print(\"integrated charge density: np.sum(nofr):\", ntot)\n",
    "\n",
    "# nofr is a density per l^2; A_uc * n(r) with A in nm^2 gives the dimensionless A/l^2 factor\n",
    "ATimesnofr = nofr * (A/l**2)\n",
    "\n",
    "ATimesnofrSaveName=(\"ATimesnofr\"+paramString+\"xMax%.2fnumx%dunitsnm\"%(xMax,numx))\n",
    "np.save(dataDir+ATimesnofrSaveName, ATimesnofr)\n",
    "\n",
    "imshow = ax.imshow(ATimesnofr, cmap = colormap, origin = 'lower', extent = [-xMax,xMax,-xMax,xMax], interpolation=\"gaussian\")\n",
    "#plt.contour(np.abs(nofr), origin='lower', color='blue')\n",
    "cMin = 0\n",
    "cMax = np.amax(ATimesnofr)\n",
    "imshow.set_clim(cMin, cMax)\n",
    "plt.ylabel(r'$y$ $(nm)$')\n",
    "plt.xlabel(r'$x$ $(nm)$')\n",
    "#plt.xticks([-4,-2,0,2,4])\n",
    "#plt.yticks([-4,-2,0,2,4])\n",
    "#plt.title(r'$l^2n(\\mathbf{r})$')\n",
    "cbar = plt.colorbar(imshow, fraction=0.046, pad=0.04)\n",
    "cbar.ax.set_ylabel(r'$A_{u.c.}n(\\mathbf{r})$')\n",
    "plt.savefig(figDir+'WigMolCharge'+paramString+'.pdf',bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#1D Ewald\n",
    "# Direct, truncated 1D lattice sum U(x) = sum_{0<|j|<=nImages} 1/|x - j| for unit charges on\n",
    "# the integer sites other than the origin, plotted relative to its minimum. The cell uses its\n",
    "# own names so it no longer clobbers the builtin max() or the globals x, numx and xvals.\n",
    "\n",
    "nImages=20\n",
    "xmaxEwald=0.8\n",
    "numxEwald=100\n",
    "xvalsEwald=np.linspace(-xmaxEwald,xmaxEwald,numxEwald)\n",
    "sites=np.array([j for j in range(-nImages,nImages+1) if j != 0])\n",
    "U=np.sum(1/np.abs(xvalsEwald[:,None]-sites[None,:]),axis=1)\n",
    "plt.plot(xvalsEwald,U-np.amin(U));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peak value of the most recently computed density grid\n",
    "nofr.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cut through the middle row of ATimesnofr. The x grid is rebuilt from the array's own size;\n",
    "# the hard-coded linspace(-4,4,50) / row 25 did not match the 30-point grid computed above.\n",
    "# NOTE(review): assumes xMax still holds the value ATimesnofr was computed with.\n",
    "nGrid = ATimesnofr.shape[0]\n",
    "plt.plot(np.linspace(-xMax,xMax,nGrid),ATimesnofr[nGrid//2])\n",
    "# NOTE(review): 3.64657445566 is an unexplained reference separation; document its origin.\n",
    "plt.vlines([-3.64657445566/2, +3.64657445566/2],0,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xMax = 7 # nm\n",
    "numx = 100\n",
    "\n",
    "nofr = computenofr(gsVecs[0],oneParticleDiffDictFD,fockBasisFD,FDBasis,xMax/l,numx)\n",
    "\n",
    "# Normalisation check on the wider, finer grid. The linspace spacing is 2*xMax/(numx-1)\n",
    "# (units of l). The unused plt.gca()/colormap lines are gone; gca() only opened an empty figure.\n",
    "ntot = np.sum(nofr) * (2*(xMax/l)/(numx-1))**2\n",
    "print(\"integrated charge density: np.sum(nofr):\", ntot)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalised cut through the middle row of the density just computed, compared with the\n",
    "# Gaussian exp(-x^2/l^2). Uses nofr, not ATimesnofr: the latter still holds the earlier\n",
    "# 30-point grid, so indexing it at numx//2 = 50 raised IndexError. The constant A/l^2 cancels\n",
    "# in the normalisation anyway.\n",
    "xvals=np.linspace(-xMax,xMax,numx)\n",
    "cut = nofr[numx//2, :]\n",
    "plt.plot(xvals, cut/np.amax(cut))\n",
    "plt.plot(xvals, np.exp(-xvals**2 / (l**2)), color=\"red\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#WIGNER MOLECULE CONVERGENCE DATA\n",
    "# Recorded notes only, kept as comments. As live code, 'phi=5 deg' is a SyntaxError and\n",
    "# 'N=3,' would rebind the global N to a tuple.\n",
    "# NOTE(review): the values are presumably the lowest eigenvalue Es[0] in units of hbar*omega.\n",
    "\n",
    "# V = 20 #meV\n",
    "# phi = 5.0 * np.pi/180 #-30 * np.pi/180\n",
    "# eps = 5\n",
    "# mStar = 1\n",
    "# N=9\n",
    "# NU=1\n",
    "# ND=1\n",
    "\n",
    "# N=3, \n",
    "# N=4, \n",
    "# N=5, \n",
    "# N=6, \n",
    "# N=7, \n",
    "# N=8, \n",
    "# N=9, -3.1851396755\n",
    "\n",
    "# phi=5 deg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#WIGNER MOLECULE CONVERGENCE DATA\n",
    "\n",
    "#third order taylor expansion\n",
    "\n",
    "# V = 20 #meV\n",
    "# phi = 5.0 * np.pi/180 #-30 * np.pi/180\n",
    "# eps = 5\n",
    "# mStar = 1\n",
    "# N=9\n",
    "# NU=1\n",
    "# ND=1\n",
    "\n",
    "N=3, -3.1298634676\n",
    "N=4, -3.1563306757\n",
    "N=5, -3.156334321\n",
    "N=6, -3.1592522373\n",
    "N=7, -3.1592532192\n",
    "N=8, -3.159548429\n",
    "N=9, -3.1595486221\n",
    "\n",
    "phi=5 deg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "PLOT CHARGE DENSITY OF MANY BODY STATE\n",
    "\"\"\"\n",
    "\n",
    "def computenofr(eState,oneParticleDiffDict,fockBasis,FDBasis,xMax,numx,cutoff=1e-7):\n",
    "    \"\"\"Real-space density n(r) of a many-body state, sampled on a numx x numx grid.\n",
    "\n",
    "    eState              : coefficients of the state in fockBasis.\n",
    "    oneParticleDiffDict : maps a Fock-state index i to indices j used for off-diagonal terms;\n",
    "                          presumably pairs differing by one occupied spin-orbital (a single\n",
    "                          element is taken from each set difference) - confirm with the builder.\n",
    "    fockBasis           : occupied spin-orbital indices of each Fock state; index sp uses spatial\n",
    "                          orbital FDBasis[sp % nSp], and sp >= nSp is the other spin species.\n",
    "                          Entries are compared elementwise (fj > sp), so presumably numpy arrays.\n",
    "    FDBasis             : (n, l) pairs passed to fockDarwin; nSp = number of rows of FDBasis.\n",
    "    xMax, numx          : grid = linspace(-xMax, xMax, numx) along each axis (callers pass xMax/l).\n",
    "    cutoff              : coefficients with |c| <= cutoff are skipped (formerly hard-coded 1e-7).\n",
    "\n",
    "    Returns a real (numx, numx) array whose [i, j] entry is the density at grid point\n",
    "    (xVals[i], xVals[j]), with polar angle theta = arctan2(xVals[i], xVals[j]).\n",
    "    \"\"\"\n",
    "    xVals=np.linspace(-xMax,xMax,numx)\n",
    "    nSp=np.shape(FDBasis)[0]\n",
    "    # Flattened grid: point k = j + i*numx sits at (xVals[i], xVals[j]).\n",
    "    first,second=np.meshgrid(xVals,xVals,indexing='ij')\n",
    "    rho=np.hypot(first,second).ravel()\n",
    "    # arctan2(x, y) equals the former half-angle form 2*arctan(x/(y+rho)) but stays finite at\n",
    "    # the origin and on the branch x=0, y<0, where that form gave 0/0 = nan (hit for odd numx).\n",
    "    theta=np.arctan2(first,second).ravel()\n",
    "    rhoThetaVals=np.column_stack((rho,theta))\n",
    "\n",
    "    orbitalCache={}\n",
    "    def orbital(sp):\n",
    "        # fockDarwin depends only on the spatial orbital, so evaluate each one once per call.\n",
    "        key=sp%nSp\n",
    "        if key not in orbitalCache:\n",
    "            n,lOrb=FDBasis[key]\n",
    "            orbitalCache[key]=fockDarwin(n,lOrb,rhoThetaVals)\n",
    "        return orbitalCache[key]\n",
    "\n",
    "    nofr=np.zeros((numx**2),dtype=complex)\n",
    "    #diagonal in fock basis\n",
    "    for i,fi in enumerate(fockBasis):\n",
    "        if abs(eState[i])>cutoff:\n",
    "            for sp in fi:\n",
    "                nofr+=abs(eState[i]*orbital(sp))**2\n",
    "    #off-diagonal in fock basis\n",
    "    for i in oneParticleDiffDict:\n",
    "        if not abs(eState[i])>cutoff: #save computation time by neglecting negligible terms\n",
    "            continue\n",
    "        setI=set(fockBasis[i])\n",
    "        for j in oneParticleDiffDict[i]:\n",
    "            if not abs(eState[j])>cutoff:\n",
    "                continue\n",
    "            fj=fockBasis[j]\n",
    "            setJ=set(fj)\n",
    "            spa=(setI-setJ).pop()\n",
    "            spb=(setJ-setI).pop()\n",
    "            if (spa>=nSp)!=(spb>=nSp): #different spin: no contribution\n",
    "                continue\n",
    "            # Ordering sign: parity of the occupied orbitals of fj strictly between spa and spb.\n",
    "            lo,hi=min(spa,spb),max(spa,spb)\n",
    "            sign=(-1)**np.count_nonzero((fj>lo)&(fj<hi))\n",
    "            # NOTE(review): coefficients pair as conj(c_j)*c_i; the textbook <Psi|c_a^dag c_b|Psi>\n",
    "            # term uses conj(c_i)*c_j. They agree when eState is real up to a global phase -\n",
    "            # confirm if genuinely complex states are used.\n",
    "            nofr+=np.conj(eState[j]*orbital(spa))*eState[i]*orbital(spb)*sign\n",
    "    nofr=np.real(nofr.reshape(numx,numx))\n",
    "    return(nofr)\n",
    "\n",
    "plt.rcParams[\"figure.figsize\"] = (2.00,2.00)\n",
    "\n",
    "xMax = 4 # nm\n",
    "numx = 50\n",
    "\n",
    "nofr = computenofr(gsVecs[0],oneParticleDiffDictFD,fockBasisFD,FDBasis,xMax/l,numx)\n",
    "\n",
    "ax=plt.gca()\n",
    "colormap='magma'\n",
    "\n",
    "# Riemann-sum check of the normalization: numx points span numx-1 intervals of\n",
    "# linspace(-xMax/l, xMax/l, numx), so the grid spacing is 2*(xMax/l)/(numx-1).\n",
    "dxGrid = 2*(xMax/l)/(numx-1)\n",
    "ntot = np.sum(nofr) * dxGrid**2\n",
    "print(\"integrated charge density: sum(nofr)*dx^2:\", ntot)\n",
    "\n",
    "# Density times the plotted box area (2*xMax/l)^2; this is the array that is saved and shown.\n",
    "ATimesnofr = nofr * (2*(xMax/l))**2\n",
    "\n",
    "ATimesnofrSaveName=(\"ATimesnofr\"+paramString+\"xMax%.2fnumx%dunitsnm\"%(xMax,numx))\n",
    "np.save(dataDir+ATimesnofrSaveName, ATimesnofr)\n",
    "\n",
    "# Rows of ATimesnofr follow computenofr's first grid coordinate; origin='lower' draws them bottom-to-top.\n",
    "imshow = ax.imshow(ATimesnofr, cmap = colormap, origin = 'lower', extent = [-xMax,xMax,-xMax,xMax], interpolation=\"gaussian\")\n",
    "cMin = 0\n",
    "cMax = np.amax(ATimesnofr)\n",
    "imshow.set_clim(cMin, cMax)\n",
    "ax.set_ylabel(r'$y$ $(nm)$')\n",
    "ax.set_xlabel(r'$x$ $(nm)$')\n",
    "# Five ticks across the extent; identical to the old hard-coded [-4,-2,0,2,4] for xMax=4.\n",
    "ticks = np.linspace(-xMax, xMax, 5)\n",
    "ax.set_xticks(ticks)\n",
    "ax.set_yticks(ticks)\n",
    "cbar = plt.colorbar(imshow, fraction=0.046, pad=0.04)\n",
    "cbar.ax.set_ylabel(r'$A_{u.c.}n(\\mathbf{r})$')\n",
    "plt.savefig(figDir+'WigMolCharge'+paramString+'.pdf',bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "a61d32b8834a82b22a55d5b09bb50782fa881bd5ee8c864a2bb525b96f9820e3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
